Introduction

sctransform::vst operates under the assumption that gene counts approximately follow a Negative Binomial distribution. For UMI-based data that seems to be the case; non-UMI data, however, does not behave the same way. In some cases it might be better to apply a transformation to such data to make it look like UMI data. Here we learn such a transformation function.

We take publicly available 10x data and learn the relationship between count distribution quantile per cell (cumulative distribution function) and the observed log-counts. Only non-zero counts are considered and cells are grouped by the number of genes detected. The learned relationship can later be used to shoehorn any cell data (e.g. read counts from non-UMI technologies) into a UMI-like distribution (a.k.a. quantile normalization).
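The per-cell quantile scores can be illustrated with a toy example (hypothetical counts; this mirrors the tabulation used for training below, where each distinct count value gets the cumulative fraction of observations at or below it):

```r
# toy non-zero counts of a single hypothetical cell
y <- c(1, 1, 1, 2, 2, 3, 7)
tbl <- as.data.frame(table(y))         # tabulate distinct count values
tbl$y <- as.numeric(as.character(tbl$y))
tbl$N <- tbl$Freq                      # number of genes with that count
tbl$q <- cumsum(tbl$N) / sum(tbl$N)    # cumulative fraction = quantile score
tbl[, c("y", "N", "q")]                # q always ends at 1 for a cell
```

The quantile score q is what links cells with very different sequencing depths: the highest count in any cell maps to q = 1 regardless of its absolute value.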

The general idea is very much inspired by Townes & Irizarry, Genome Biology, 2020. There the authors present a quantile normalization approach that uses the Poisson-lognormal distribution to model UMI data.

Load training data

We are going to use various data sets from 10x to learn the quantile-to-log-counts relationship. Rather than using all cells, we want cells spanning a diverse range of total genes detected (minimum 200), so we first group by the number of genes detected and then sample some cells per group.

set.seed(85848484)

h5_files <- list.files(path = file.path("~/Projects/data_warehouse/raw_public_10x"), 
    pattern = ".*\\.h5$", full.names = TRUE)

# constants used for selecting training cells
G <- 33  # number of groups, see below
N <- 111  # number of cells per group, also see below

counts_list <- lapply(1:length(h5_files), function(h5i) {
    message(basename(h5_files[h5i]))
    counts <- Seurat::Read10X_h5(filename = h5_files[h5i])
    if (is.list(counts)) {
        counts <- counts[["Gene Expression"]]
    }
    # sample cells by genes detected group
    keep <- data.frame(idx = 1:ncol(counts), genes = colSums(counts > 0)) %>% filter(genes >= 
        200) %>% mutate(grp = cut(log10(genes), G)) %>% group_by(grp) %>% sample_n(size = min(n(), 
        N)) %>% pull(idx)
    
    # for each cell we keep a table of counts
    h5_counts_list <- lapply(keep, function(i) {
        y <- counts[, i]
        y <- y[y > 0]
        tbl <- data.frame(table(y))
        colnames(tbl) <- c("y", "N")
        tbl$y <- as.numeric(as.character(tbl$y))
        tbl$file <- h5i
        tbl$cell <- i
        tbl
    })
    counts_df <- do.call(rbind, h5_counts_list)
    counts_df
})
#> 5k_pbmc_NGSC3_aggr_filtered_feature_bc_matrix.h5
#> heart_10k_v3_filtered_feature_bc_matrix.h5
#> malt_10k_protein_v3_filtered_feature_bc_matrix.h5
#> Genome matrix has multiple modalities, returning a list of matrices for this genome
#> Parent_NGSC3_DI_HodgkinsLymphoma_filtered_feature_bc_matrix.h5
#> Parent_NGSC3_DI_PBMC_filtered_feature_bc_matrix.h5
#> Parent_SC3v3_Human_Glioblastoma_filtered_feature_bc_matrix.h5
#> pbmc_10k_protein_v3_filtered_feature_bc_matrix.h5
#> Genome matrix has multiple modalities, returning a list of matrices for this genome
#> pbmc_10k_v3_filtered_feature_bc_matrix.h5
#> sc5p_v2_hs_melanoma_10k_filtered_feature_bc_matrix.h5
#> sc5p_v2_hs_PBMC_10k_filtered_feature_bc_matrix.h5
#> Genome matrix has multiple modalities, returning a list of matrices for this genome
#> vdj_v1_hs_nsclc_5gex_filtered_gene_bc_matrices_h5.h5
counts_df <- do.call(rbind, counts_list)

Show what the training data looks like at this point (first 100 rows). y is the count value, N is the number of observations (genes) with that count.

DT::datatable(counts_df[1:100, ], rownames = FALSE)

Remove outliers

Outlier cells are those that do not fit the total counts vs genes detected relationship that most cells show. We use loess to fit the general trend and flag the cells that have a high residual.

# cell attributes
ca <- group_by(counts_df, file, cell) %>% summarise(total = sum(y * N), genes = sum(N), 
    .groups = "drop") %>% mutate(log_total = log10(total), log_genes = log10(genes))
ggplot(ca, aes(log_total, log_genes)) + geom_point(alpha = 0.1, shape = 16) + geom_smooth()
#> `geom_smooth()` using method = 'gam' and formula 'y ~ s(x, bs = "cs")'

ca$is_outlier1 <- abs(scale(residuals(loess(log_genes ~ log_total, data = ca)))) > 2.5
ca$is_outlier2 <- abs(scale(residuals(loess(log_total ~ log_genes, data = ca)))) > 2.5
ca$is_outlier <- ca$is_outlier1 & ca$is_outlier2
ggplot(ca, aes(log_total, log_genes, color = is_outlier)) + geom_point(shape = 16)

ca <- filter(ca, !is_outlier) %>% select(1:6)
counts_df <- left_join(ca, counts_df, by = c("file", "cell"))

The final training data consists of 27164 cells from 11 files.

Fit the relationship

First add the log-transformed count, our target variable. Then, per cell, turn the counts into distribution quantiles (q).

counts_df <- group_by(counts_df, file, cell) %>% mutate(log_y = log10(y), q = cumsum(N/sum(N)))

Group the cells based on how many genes were detected. For this we log10-transform the number of genes detected and create 20 equally spaced bins. Later, when the fit is used for prediction, we linearly interpolate between the predictions of the two closest groups.
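The interpolation step mentioned above can be sketched as follows (a hypothetical helper, not part of the fit itself; `grp_mids` and `grp_preds` are assumed inputs: the group midpoints on the log10-genes axis and a matrix of per-group predicted log-counts):

```r
# Blend the fitted curves of the two nearest groups by linear interpolation
interpolate_groups <- function(log_genes, grp_mids, grp_preds) {
  # grp_mids: group midpoints on the log10-genes axis (length K, sorted)
  # grp_preds: matrix of predicted log-counts, one row per group
  i <- findInterval(log_genes, grp_mids, all.inside = TRUE)
  w <- (log_genes - grp_mids[i]) / (grp_mids[i + 1] - grp_mids[i])
  w <- min(max(w, 0), 1)  # clamp cells outside the training range
  (1 - w) * grp_preds[i, ] + w * grp_preds[i + 1, ]
}
```

Cells whose gene count falls outside the range seen in training are clamped to the nearest group rather than extrapolated.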

K <- 20  # number of groups
w <- diff(range(counts_df$log_genes)) * (1 + 1e-06)/K
breaks <- min(counts_df$log_genes) + (0:K) * w
counts_df$grp <- cut(counts_df$log_genes, breaks = breaks, right = FALSE)
tab <- select(counts_df, file, cell, grp) %>% distinct() %>% pull(grp) %>% table()
tab <- data.frame(tab)
colnames(tab) <- c("Genes detected [log10]", "Cells")
knitr::kable(tab)
Genes detected [log10] Cells
[2.3,2.39) 366
[2.39,2.47) 540
[2.47,2.56) 781
[2.56,2.64) 1010
[2.64,2.73) 1035
[2.73,2.82) 1239
[2.82,2.9) 1519
[2.9,2.99) 1865
[2.99,3.07) 1964
[3.07,3.16) 2128
[3.16,3.25) 2141
[3.25,3.33) 2177
[3.33,3.42) 2094
[3.42,3.5) 2089
[3.5,3.59) 1958
[3.59,3.67) 1731
[3.67,3.76) 1237
[3.76,3.85) 791
[3.85,3.93) 387
[3.93,4.02) 112

Show first 100 rows of training data

DT::datatable(counts_df[1:100, ], rownames = FALSE) %>% DT::formatRound(c(5, 6, 9, 10), 
    digits = 2)

Fit trend per group with loess.

models <- mutate(counts_df, log_y = log10(y)) %>% group_by(grp) %>% do(fit = loess(log_y ~ 
    q, data = ., span = 0.2, degree = 1))

Show fit per group

gdf <- group_by(models, grp) %>% do(purrr::map_dfr(.x = .$fit, .f = function(x) data.frame(q = x$x, 
    fitted_log_y = x$fitted, log_y = x$y)))
filter(gdf, fitted_log_y >= 0) %>% ggplot(aes(q, fitted_log_y, color = grp)) + geom_line() + 
    guides(color = guide_legend(ncol = 2)) + ggtitle("Expected counts [log10] as function of quantile\nGrouped by number of genes detected [log10]")

Zoom into right part of plot

filter(gdf, fitted_log_y >= 0, q >= 0.93) %>% ggplot(aes(q, fitted_log_y, color = grp)) + 
    geom_line() + guides(color = guide_legend(ncol = 2))

Create final model and save

We could now use the loess models to apply the transformation function to new data (calculate quantiles for the new data, then use predict.loess). However, the loess models are quite large, since all the training data is part of the models. We want to distribute the models with the package, so we boil them down here.

We save the predicted log-counts for 512 evenly spaced quantile values. This is enough information to later use linear interpolation for simple and fast predictions of log-counts given quantile scores and the number of genes detected.

q_out <- seq(min(counts_df$q), max(counts_df$q), length = 512)
fit_df <- group_by(models, grp) %>% do(purrr::map_dfr(.x = .$fit, .f = function(x) data.frame(q = q_out, 
    log_y = predict(x, newdata = q_out)))) %>% filter(log_y >= 0)
umify_data <- list(fit_df = data.frame(fit_df), grp_breaks = breaks)
save(umify_data, file = "../data/umify_data.rda")
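At prediction time, the saved 512-point grid can be evaluated at arbitrary quantile scores with `stats::approx` (a minimal sketch; `predict_from_grid`, `grid_q`, and `grid_log_y` are hypothetical names for one group's columns of `fit_df`):

```r
# Linearly interpolate a group's (q, log_y) grid at new quantile scores;
# rule = 2 clamps values outside the grid to the nearest endpoint
predict_from_grid <- function(q_new, grid_q, grid_log_y) {
  stats::approx(x = grid_q, y = grid_log_y, xout = q_new, rule = 2)$y
}
```

This avoids shipping the full loess models while keeping predictions fast.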

Show first and last 6 rows of final model

DT::datatable(head(fit_df, 6), rownames = FALSE)
DT::datatable(tail(fit_df, 6), rownames = FALSE)

3D visualization of final model

fit_df <- group_by(models, grp) %>% do(purrr::map_dfr(.x = .$fit, .f = function(x) data.frame(q = q_out, 
    log_y = predict(x, newdata = q_out))))
xo <- q_out
yo <- head(breaks, -1) + w/2
plotly::plot_ly(x = xo, y = yo, z = t(matrix(fit_df$log_y, nrow = length(xo))), type = "surface") %>% 
    plotly::layout(title = "UMI count as function of quantile and number of genes detected", 
        scene = list(camera = list(eye = list(x = -1, y = -2.25, z = 1)), xaxis = list(title = "Quantile"), 
            yaxis = list(title = "Genes [log10]"), zaxis = list(title = "UMI count [log10]")))

Session info and runtime

Session info

sessionInfo()
#> R version 4.0.2 (2020-06-22)
#> Platform: x86_64-apple-darwin17.0 (64-bit)
#> Running under: macOS Catalina 10.15.7
#> 
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] dplyr_1.0.2   knitr_1.30    ggplot2_3.3.2 Matrix_1.2-18
#> 
#> loaded via a namespace (and not attached):
#>   [1] nlme_3.1-149           matrixStats_0.57.0     bit64_4.0.5           
#>   [4] RcppAnnoy_0.0.16       RColorBrewer_1.1-2     httr_1.4.2            
#>   [7] sctransform_0.3.2.9003 tools_4.0.2            DT_0.16               
#>  [10] R6_2.5.0               irlba_2.3.3            rpart_4.1-15          
#>  [13] KernSmooth_2.23-17     uwot_0.1.8.9001        mgcv_1.8-33           
#>  [16] lazyeval_0.2.2         colorspace_2.0-0       withr_2.3.0           
#>  [19] tidyselect_1.1.0       gridExtra_2.3          bit_4.0.4             
#>  [22] compiler_4.0.2         formatR_1.7            hdf5r_1.3.2           
#>  [25] plotly_4.9.2.1         labeling_0.4.2         Seurat_3.9.9.9008     
#>  [28] scales_1.1.1           lmtest_0.9-38          spatstat.data_1.4-3   
#>  [31] ggridges_0.5.2         pbapply_1.4-3          spatstat_1.64-1       
#>  [34] goftest_1.2-2          stringr_1.4.0          digest_0.6.27         
#>  [37] spatstat.utils_1.17-0  rmarkdown_2.5          pkgconfig_2.0.3       
#>  [40] htmltools_0.5.1.1      highr_0.8              fastmap_1.0.1         
#>  [43] htmlwidgets_1.5.2      rlang_0.4.9            shiny_1.5.0           
#>  [46] farver_2.0.3           generics_0.0.2         zoo_1.8-8             
#>  [49] jsonlite_1.7.2         crosstalk_1.1.0.1      ica_1.0-2             
#>  [52] magrittr_2.0.1         patchwork_1.1.0.9000   Rcpp_1.0.5            
#>  [55] munsell_0.5.0          abind_1.4-5            reticulate_1.16       
#>  [58] lifecycle_0.2.0        stringi_1.5.3          yaml_2.2.1            
#>  [61] MASS_7.3-53            Rtsne_0.15             plyr_1.8.6            
#>  [64] grid_4.0.2             parallel_4.0.2         listenv_0.8.0         
#>  [67] promises_1.1.1         ggrepel_0.8.2          crayon_1.3.4.9000     
#>  [70] deldir_0.1-29          miniUI_0.1.1.1         lattice_0.20-41       
#>  [73] cowplot_1.1.0          splines_4.0.2          tensor_1.5            
#>  [76] pillar_1.4.7           igraph_1.2.6           future.apply_1.6.0    
#>  [79] reshape2_1.4.4         codetools_0.2-16       leiden_0.3.3          
#>  [82] glue_1.4.2             evaluate_0.14          data.table_1.13.2     
#>  [85] vctrs_0.3.5            png_0.1-7              httpuv_1.5.4          
#>  [88] gtable_0.3.0           RANN_2.6.1             purrr_0.3.4           
#>  [91] polyclip_1.10-0        tidyr_1.1.2            future_1.19.1         
#>  [94] xfun_0.19              rsvd_1.0.3             mime_0.9              
#>  [97] xtable_1.8-4           later_1.1.0.1          survival_3.2-3        
#> [100] viridisLite_0.3.0      tibble_3.0.4           cluster_2.1.0         
#> [103] globals_0.13.1         fitdistrplus_1.1-1     ellipsis_0.3.1        
#> [106] ROCR_1.0-11

Runtime

print(proc.time() - tic)
#>     user   system  elapsed 
#> 1502.264   36.831 1549.519